from typing import Any, Optional, NamedTuple, Iterable, Callable
import haiku as hk
import numpy as np
import optax
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from tqdm import tqdm
import tensorflow_probability.substrates.jax as tfp

# Shorthand for TensorFlow Probability's JAX distributions; `train_model`
# uses `tfd.Normal` for the perturbation noise of the sample regularizer.
tfd = tfp.distributions

# Discount factor; the sample regularizer in `train_model` weights its term
# by (1 - DISCOUNT).
DISCOUNT: float = 0.99

# Number of optimization steps used by `train_model`.
n_iters: int = 500

# Training state threaded through each jitted optimisation step: the optax
# optimizer state first, then the current haiku parameter tree. Being a
# NamedTuple, it is a JAX pytree and can be passed through `jax.jit`.
State = NamedTuple(
    "State",
    [
        ("opt_state", optax.OptState),
        ("params", hk.Params),
    ],
)


def make_critic(hidden_layer_sizes: Iterable[int] = (256, 256)) -> Callable:
    """Build a function, for use with ``hk.transform``, mapping actions to Q-values.

    Args:
      hidden_layer_sizes: Widths of the hidden layers. The default is a tuple
        rather than a list so that no mutable default is shared between
        calls. Any iterable of ints is accepted, including one-shot
        iterators.

    Returns:
      ``critic_mlp(act)``, which applies an ELU MLP ending in a single linear
      output unit and squeezes the result.
    """
    # Materialize the layer sizes once, up front. Rebuilding them inside
    # `critic_mlp` would consume a one-shot iterator during `init` and then
    # produce a different (output-only) architecture during `apply`.
    layer_sizes = list(hidden_layer_sizes) + [1]

    def critic_mlp(act):
        network = hk.nets.MLP(
            layer_sizes,
            w_init=hk.initializers.VarianceScaling(0.1, "fan_in", "uniform"),
            activation=jax.nn.elu,
            activate_final=False,
        )
        # NOTE: `.squeeze()` drops *every* size-1 axis. A (batch, 1) input
        # yields shape (batch,), but a single-row batch yields a 0-d scalar.
        return network(act).squeeze()

    return critic_mlp


def train_model(
    network,
    x,
    y,
    jacobian_reg=False,
    sample_reg=False,
    *,
    seed: int = 0,
    learning_rate: float = 1e-4,
    num_iters: Optional[int] = None,
):
    """Fit ``network`` to targets ``y`` at inputs ``x`` with Adam.

    The base loss is the mean squared error between ``y`` and
    ``network.apply(params, x)``. Two optional penalties may be added:

    * ``jacobian_reg``: the summed squared Jacobian of the network output with
      respect to its input ``x``.
    * ``sample_reg``: ``(1 - DISCOUNT)`` times the mean network output at
      ``x`` perturbed by Gaussian noise (loc 0, scale 0.1).

    Args:
      network: A transformed haiku function exposing ``init(key, x)`` and
        ``apply(params, x)`` (no apply-time RNG).
      x: Training inputs; the first row is used to initialize parameters.
        The perturbation noise is one-dimensional, so presumably ``x`` is
        (batch, 1). TODO: confirm before using with wider inputs.
      y: Regression targets, broadcast against the network output.
      jacobian_reg: Whether to add the Jacobian penalty.
      sample_reg: Whether to add the noisy-sample penalty.
      seed: Seed for parameter initialization and noise sampling.
      learning_rate: Adam learning rate.
      num_iters: Number of optimization steps; defaults to the module-level
        ``n_iters``.

    Returns:
      ``critic(x)``, which evaluates the trained network at ``x``.
    """
    if num_iters is None:
        num_iters = n_iters

    # `key` is only ever used as the parent for further splits. Every
    # consumer of randomness gets its own fresh subkey, so no key is reused.
    key = jax.random.PRNGKey(seed)
    key, init_key = jax.random.split(key)
    params = network.init(init_key, x[0, None, :])

    opt = optax.adam(learning_rate)
    state = State(opt.init(params), params)

    # Perturbation distribution for the sample regularizer (batch shape (1,)).
    policy = tfd.Normal(loc=jnp.zeros((1,)), scale=0.1 * jnp.ones((1,)))

    # Jacobian of the network output w.r.t. its input
    # (argnums=1 selects `x`, not the params).
    jacobian = jax.jacfwd(network.apply, 1)

    def loss_fn(params, x, y, key):
        """Return the (optionally regularized) training loss for ``params``."""
        q = network.apply(params, x)
        loss = ((y - q) ** 2).mean()

        if jacobian_reg:
            jac = jacobian(params, x)
            loss += (jac ** 2).sum()

        if sample_reg:
            # One noise draw per input row. The sample shape is
            # x.shape[:-1] + batch shape (1,), i.e. it matches x row for row.
            # Using `x.shape` as the sample shape would broadcast against x
            # into a (batch, batch, 1) grid instead.
            noise = policy.sample(seed=key, sample_shape=x.shape[:-1])
            q_sample = network.apply(params, x + noise)
            loss += (1.0 - DISCOUNT) * q_sample.mean()

        return loss

    @jax.jit
    def step(state, x, y, key):
        """Apply one Adam update; return the new state and the pre-update loss."""
        loss_value, grad = jax.value_and_grad(loss_fn)(state.params, x, y, key)
        updates, opt_state = opt.update(grad, state.opt_state)
        params = optax.apply_updates(state.params, updates)
        return State(opt_state, params), loss_value

    for _ in tqdm(range(num_iters)):
        key, step_key = jax.random.split(key)
        state, _ = step(state, x, y, step_key)

    trained_params = state.params

    def critic(x):
        """Evaluate the trained network at ``x``."""
        return network.apply(trained_params, x)

    return critic


def main():
    """Train two critics on a single target point and plot Q and dQ/da.

    Both critics are fit to Q(a=0) = 2. One uses Jacobian regularization and
    the other uses sample regularization. The left panel shows Q over actions
    in [-5, 5] and the right panel shows dQ/da. The figure is saved to
    ``critic_reg.pdf``.
    """
    # A single (action, target) training pair, plus a dense action grid for
    # plotting.
    target = 2 * jnp.ones((1, 1))
    action = jnp.zeros((1, 1))
    actions = jnp.linspace(-5, 5, 200)[:, None]

    network = hk.without_apply_rng(hk.transform(make_critic()))

    # Dict literal entries are evaluated in order, so the two models are
    # trained in the same sequence as before.
    models = {
        'Jacobian regularization': train_model(
            network, action, target, jacobian_reg=True),
        'sample regularization': train_model(
            network, action, target, sample_reg=True),
    }

    plt.rc("text", usetex=True)
    plt.rc("font", family="serif", size=9)

    fig, (q_ax, dq_ax) = plt.subplots(1, 2, figsize=(6.6, 1.3))

    # Dotted reference axes on the derivative panel.
    dq_ax.hlines(y=0, xmin=-5, xmax=5, color='k', linestyles='dotted', alpha=0.5)
    dq_ax.vlines(x=0, ymin=-1, ymax=1, color='k', linestyles='dotted', alpha=0.5)

    for name, model in models.items():
        q_ax.plot(actions, model(actions), label=name)
        # The model maps (N, 1) -> (N,), so its Jacobian is (N, N, 1). The MLP
        # acts row-wise, so summing over the input axes leaves dQ/da at each
        # grid point.
        dqda = jnp.einsum('bij->b', jax.jacfwd(model)(actions))
        dq_ax.plot(actions, dqda)

    q_ax.plot(action, target, 'kx', markersize=5)
    q_ax.legend(frameon=False)
    q_ax.set_ylim(0, 5)
    q_ax.set_ylabel("$Q$")
    dq_ax.set_ylabel("$\\partial Q / \\partial a$")

    for ax in (q_ax, dq_ax):
        ax.set_xlim(-5, 5)
        ax.set_xticklabels([])
        ax.set_xticks([])
        ax.set_yticklabels([])
        ax.set_yticks([])
        ax.set_xlabel("$a$")

    fig.tight_layout()
    fig.savefig("critic_reg.pdf", bbox_inches="tight")


if __name__ == "__main__":
    # Script entry point: build and save the figure, then show it on screen.
    main()
    plt.show()